# ======================================================================
# Copyright CERFACS (March 2019)
# Contributor: Adrien Suau (adrien.suau@cerfacs.fr)
#
# This software is governed by the CeCILL-B license under French law and
# abiding  by the  rules of  distribution of free software. You can use,
# modify  and/or  redistribute  the  software  under  the  terms  of the
# CeCILL-B license as circulated by CEA, CNRS and INRIA at the following
# URL "http://www.cecill.info".
#
# As a counterpart to the access to  the source code and rights to copy,
# modify and  redistribute granted  by the  license, users  are provided
# only with a limited warranty and  the software's author, the holder of
# the economic rights,  and the  successive licensors  have only limited
# liability.
#
# In this respect, the user's attention is drawn to the risks associated
# with loading,  using, modifying and/or  developing or reproducing  the
# software by the user in light of its specific status of free software,
# that  may mean  that it  is complicated  to manipulate,  and that also
# therefore  means that  it is reserved for  developers and  experienced
# professionals having in-depth  computer knowledge. Users are therefore
# encouraged  to load and  test  the software's  suitability as  regards
# their  requirements  in  conditions  enabling  the  security  of their
# systems  and/or  data to be  ensured and,  more generally,  to use and
# operate it in the same conditions as regards security.
#
# The fact that you  are presently reading this  means that you have had
# knowledge of the CeCILL-B license and that you accept its terms.
# ======================================================================

import warnings

import numpy
import qaths.utils.constants
import scipy.linalg
import scipy.sparse


def kron(start: numpy.ndarray, *others: numpy.ndarray):
    r"""Return the Kronecker product :math:`start \otimes others_0 \otimes \dots`.

    The product is associated from the right, i.e.
    ``kron(a, b, c) == numpy.kron(a, numpy.kron(b, c))``.

    :param start: The left-most factor.
    :param others: The remaining factors, in order.
    :return: The Kronecker product of all the factors. When no other factor is
        given, ``start`` itself is returned (same object, no copy).
    """
    if not others:
        return start

    # Fold from the right-most factor towards ``start`` so the association
    # order matches the recursive definition exactly.
    product = others[-1]
    for factor in reversed((start,) + others[:-1]):
        product = numpy.kron(factor, product)
    return product


def expm(matrix: numpy.ndarray):
    """Return the matrix exponential of ``matrix`` via ``scipy.linalg.expm``.

    Any ``scipy.sparse.SparseEfficiencyWarning`` emitted while computing the
    exponential is silenced; the global warning filters are restored on exit.

    :param matrix: The square matrix to exponentiate.
    :return: :math:`e^{matrix}`.
    """
    with warnings.catch_warnings():
        warnings.filterwarnings(
            "ignore", category=scipy.sparse.SparseEfficiencyWarning
        )
        exponential = scipy.linalg.expm(matrix)
    return exponential


def normalise(vector, norm=2) -> numpy.ndarray:
    """Scale ``vector`` so that its ``norm``-norm equals 1.

    :param vector: A ``numpy.ndarray`` or anything ``numpy.array`` accepts.
    :param norm: The ``ord`` argument forwarded to ``numpy.linalg.norm``
        (2 by default).
    :return: ``vector`` divided by its norm, as a new array.
    """
    array = vector if isinstance(vector, numpy.ndarray) else numpy.array(vector)
    magnitude = numpy.linalg.norm(array, norm)
    return array / magnitude


def special_unitary(unitary: numpy.ndarray) -> numpy.ndarray:
    r"""Return the special-unitary representative of ``unitary``.

    ``unitary`` is multiplied by the global phase :math:`e^{-i\phi/n}`, where
    :math:`\phi` is the argument of :math:`\det(unitary)` and :math:`n` its
    dimension, so that the determinant of the result is 1.

    :param unitary: A square unitary matrix. Its determinant is assumed to
        have modulus 1; only its phase is corrected.
    :return: ``unitary`` rescaled by a unit-modulus scalar so that its
        determinant is 1.
    :raises AssertionError: (when assertions are enabled) if the resulting
        determinant is not within ``1e-7`` of 1, e.g. because ``unitary`` is
        not unitary.
    """
    determinant = numpy.linalg.det(unitary)
    # The determinant has modulus 1, so only its phase needs to be cancelled.
    phi = numpy.angle(determinant)
    n = unitary.shape[0]
    # det(c * U) = c**n * det(U), hence c = exp(-i * phi / n).
    coefficient = numpy.exp(-1.0j * phi / n)
    # Compute the rescaled matrix once and reuse it for the check and return.
    result = coefficient * unitary
    result_determinant = numpy.linalg.det(result)
    assert abs(result_determinant - 1.0) < 1e-7, (
        "Determinant of the rescaled matrix is {} instead of 1; is the input "
        "unitary?".format(result_determinant)
    )
    return result


def generate_random_hermitian_permutation(n: int, include_conj: bool = False):
    r"""Generate a random symmetrical permutation.

    The returned permutation :math:`M` is hermitian. Because a permutation matrix
    is only composed of real numbers, the returned permutation also represents
    a symmetric matrix.

    The permutation is returned as a 1D array of integers. If the entry :math:`i`
    contains the value :math:`j`, this means that:

        - :math:`M_{ij}` = 1
        - :math:`M_{ji}` = 1 (symmetry)

    Every index is paired with a different index (the permutation has no
    fixed point), so applying it twice gives the identity.

    :param n: Number of qubits. The returned permutation will represent
        a :math:`\left(2^n, 2^n\right)` matrix. Must be at least 1.
    :param include_conj: A flag that should be set if we also want to return an
        array with flags that indicate whether or not the associated value in the
        permutation should be conjugated. Within each pair, exactly one of the
        two indices is flagged ``True``.
    :return: The permutation array, or the tuple ``(permutation, signs)`` when
        ``include_conj`` is set.
    :raises ValueError: if ``n < 1``, because a single basis state cannot be
        paired with a different one.
    """
    if n < 1:
        raise ValueError(
            "n must be at least 1 to build a hermitian permutation without "
            "fixed points, got {}.".format(n)
        )
    dimension = 2 ** n
    indices = list(range(dimension))
    # Builtin ``int``/``bool`` are the dtypes the old ``numpy.int`` /
    # ``numpy.bool`` aliases referred to; those aliases were removed in
    # NumPy 1.24 and raise AttributeError there.
    hermitian_permutation = numpy.zeros((dimension,), dtype=int)
    signs = numpy.zeros((dimension,), dtype=bool)

    # Draw two distinct remaining indices uniformly at random and pair them.
    # ``dimension`` is even, so every index ends up in exactly one pair.
    while indices:
        i = indices.pop(numpy.random.randint(len(indices)))
        j = indices.pop(numpy.random.randint(len(indices)))
        hermitian_permutation[i] = j
        hermitian_permutation[j] = i
        if include_conj:
            signs[i] = True
            signs[j] = False
    if include_conj:
        return hermitian_permutation, signs
    return hermitian_permutation


def reverse_normalised_kronecker(vec: numpy.ndarray, *sizes: int):
    """Reverse a kronecker product of normalised vectors.

    ``vec`` is reshaped into a ``(sizes[0], prod(sizes[1:]))`` matrix.
    - The first factor is the first column whose norm exceeds
      ``qaths.utils.constants.UNDER_IS_ZERO``, normalised.
    - The remaining factors are recovered recursively from the first row whose
      norm exceeds that threshold.

    If ``vec`` is exactly a kronecker product, each returned vector matches the
    corresponding original factor only up to a multiplicative scalar (a global
    phase for normalised factors). The kronecker product cannot distinguish
    such scalars.

    :param vec: The full vector, result of the kronecker product.
    :param sizes: The sizes of the vectors that have been composed with kronecker
        product to give vec.
    :return: A list of ``len(sizes)`` normalised vectors; the ``i``-th one has
        ``sizes[i]`` entries.
    :raises RuntimeError: if, at any recursion level, no column or no row of
        the reshaped matrix has a norm above
        ``qaths.utils.constants.UNDER_IS_ZERO``.
    """
    # ``numpy.prod`` replaces ``numpy.product``, which was deprecated in
    # NumPy 1.25 and removed in NumPy 2.0.
    assert vec.size == numpy.prod(sizes), "vec.size must equal the product of sizes."
    if len(sizes) == 1:
        return [normalise(vec)]

    zero_threshold = qaths.utils.constants.UNDER_IS_ZERO
    first_size = sizes[0]
    remaining_size = int(numpy.prod(sizes[1:]))

    matrix = vec.reshape((first_size, remaining_size))
    first_vec, other_vecs = None, None
    # For a product state, every non-negligible column is proportional to the
    # first factor.
    for column in range(remaining_size):
        if numpy.linalg.norm(matrix[:, column]) > zero_threshold:
            first_vec = normalise(matrix[:, column])
            break
    # For a product state, every non-negligible row is proportional to the
    # kronecker product of the remaining factors, which is reversed
    # recursively.
    for row in range(first_size):
        if numpy.linalg.norm(matrix[row, :]) > zero_threshold:
            other_vecs = reverse_normalised_kronecker(matrix[row, :], *sizes[1:])
            break
    if first_vec is None or other_vecs is None:
        raise RuntimeError(
            "Can't reverse this kronecker product because one of the "
            + "given dimension is the zero vector (and so cannot be "
            + "normalised). A norm below {}".format(zero_threshold)
            + " is considered as 0."
        )
    return [first_vec] + other_vecs


def next_power_of_2(n: int) -> int:
    r"""Return the smallest power of two that is greater than or equal to ``n``.

    Formally, return :math:`2^i` such that :math:`2^{i-1} < n \leqslant 2^i`;
    by convention ``next_power_of_2(0) == 1``.

    :param n: The number to upper bound by a power of :math:`2`.
    :return: :math:`2^i` such that :math:`2^{i-1} < n \leqslant 2^i`.
    """
    if n == 0:
        return 1
    # ``(n - 1).bit_length()`` is the exponent of the smallest power of two >= n.
    exponent = (n - 1).bit_length()
    return 1 << exponent


def complete_by_zeros(M: scipy.sparse.spmatrix):
    r"""Pad :math:`M` with zeros into a square matrix of power-of-two size.

    The result is square, of dimension :math:`\left(2^i, 2^i\right)` with
    :math:`2^i` = ``next_power_of_2(max(m, n))``, and has :math:`M` in its
    top-left corner.

    :param M: The sparse matrix of size :math:`\left(m, n\right)` to complete.
    :return: a square sparse matrix
        :math:`M' = \begin{pmatrix} M & 0 \\ 0 & 0 \end{pmatrix}` with
        ``next_power_of_2(max(m, n))`` rows and columns and the dtype of
        :math:`M`.
    :raises AssertionError: (when assertions are enabled) if ``M`` is not a
        scipy sparse matrix.
    """
    assert scipy.sparse.issparse(M), "Given matrix should be sparse."
    rows, cols = M.shape
    side = next_power_of_2(max(rows, cols))

    # Zero block appended to the right of M to reach ``side`` columns.
    right_padding = scipy.sparse.coo_matrix((rows, side - cols), dtype=M.dtype)
    # Zero block appended below to reach ``side`` rows.
    bottom_padding = scipy.sparse.coo_matrix((side - rows, side), dtype=M.dtype)

    top_band = scipy.sparse.hstack((M, right_padding))
    return scipy.sparse.vstack((top_band, bottom_padding))
